'''
Author: your name
Date: 1970-01-01 08:00:00
LastEditTime: 2020-12-10 10:03:45
LastEditors: Please set LastEditors
Description: Summed sigmoid binary cross-entropy loss over one-hot encoded class targets.
FilePath: /Pointnet_Pointnet2/models/my_loss.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

class get_loss(nn.Module):
    """Sum of per-class binary cross-entropy terms over sigmoid outputs.

    Each class is scored independently (one-vs-all). The integer ``target``
    is one-hot encoded to ``num_class`` columns and compared against
    ``sigmoid(output)`` with binary cross-entropy. The result is summed over
    every element; there is no averaging over the batch.

    Args:
        args: configuration object. Only ``args.num_class`` is read; it is
            the one-hot width, i.e. the number of classes.
    """

    def __init__(self, args):
        super(get_loss, self).__init__()
        self.cls_num = args.num_class

    def forward(self, output, target):
        """Return the summed binary cross-entropy as a 0-dim tensor.

        Computes ``-sum(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))``,
        where ``x`` is ``output`` and ``y`` is the one-hot encoding of
        ``target``.

        Args:
            output: raw (pre-sigmoid) scores. They must broadcast against the
                one-hot tensor, whose trailing dimension has size
                ``num_class``. Presumably shaped ``(batch, num_class)``;
                confirm at the call site.
            target: integer class indices. ``F.one_hot`` requires an int64
                tensor with values in ``[0, num_class)``.

        Returns:
            Scalar (0-dim) loss tensor.
        """
        # one_hot yields int64; match the logits' dtype so the arithmetic
        # below is explicitly floating point.
        y_oh = F.one_hot(target, self.cls_num).to(output.dtype)
        # logsigmoid(x) == log(sigmoid(x)) and logsigmoid(-x) == log(1 - sigmoid(x)),
        # but both are evaluated stably. The naive form hits log(0) = -inf as
        # soon as sigmoid rounds to exactly 0 or 1 for large |x|, and
        # 0 * -inf then makes the loss (and its gradient) NaN.
        log_p = F.logsigmoid(output)
        log_not_p = F.logsigmoid(-output)
        total_loss = -torch.sum(y_oh * log_p + (1 - y_oh) * log_not_p)
        return total_loss